Conversation

@charithaintc charithaintc commented Aug 26, 2025

This PR adds the features needed to support the GEMM with transpose B case.

Summary of changes:

1. Add distribution logic for the vector.bitcast, vector.transpose, and memref.extract_aligned_pointer_as_index cases.
2. Add layout propagation support for vector.shape_cast, vector.broadcast, and vector.bitcast.
3. Incorporate the slice attribute and the DistributeLayoutAttr interface into the core layout propagation logic.

@charithaintc charithaintc requested review from Jianhui-Li, adam-smnk, chencha3 and silee2 and removed request for chencha3 August 26, 2025 23:27
@charithaintc charithaintc changed the title from "[mlir][xegpu] Add SIMT distribution support GEMM transpose B case." to "[mlir][xegpu] Add SIMT distribution support for GEMM transpose B case." Aug 26, 2025
// communication. So each lane must own the required number of elements to
// perform the bitcast locally without cross-lane communication.
int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
if (outInnerBitsPerLane < inElemTyBitWidth) {
Contributor

Check the condition:
srcInnerBitsPerLane = inElemTyBitWidth x sourceLayout.getLaneData
if (outInnerBitsPerLane != srcInnerBitsPerLane)

Contributor Author

I thought about this again. sourceLayout.getLaneData is not available to us, because that is exactly what we are trying to decide here. I think we can only detect the narrowing case.

The widening case will always be valid, because at this point the result already has a valid layout; otherwise the result was not assigned a correct layout, which would be a layout-conflict concern.

In any case, I added a check to verify that the result layout is valid and can be distributed to lanes.

Contributor

I see. I would move the check after the sourceLaneData is assigned. See comments below also.
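
For reference, here is a minimal standalone sketch of the rejected case; the f32 -> f16 element types and the lane_data value are assumptions chosen purely for illustration, and the variable names simply mirror the snippet quoted above:

    #include <cstdio>

    int main() {
      // Hypothetical narrowing bitcast: source f32 (32 bits) -> result f16 (16 bits).
      // Assume the result layout assigns lane_data = [1, 1], i.e. each lane owns a
      // single f16 along the innermost dimension.
      int inElemTyBitWidth = 32;  // source element bit width (f32)
      int outElemTyBitWidth = 16; // result element bit width (f16)
      int outInnerLaneData = 1;   // assumed result lane_data[rank - 1]

      int outInnerBitsPerLane = outInnerLaneData * outElemTyBitWidth; // 16 bits
      // 16 < 32: a lane's single f16 covers only half of a source f32, so the
      // source element would have to be split across two lanes. The propagation
      // rejects this because the bitcast cannot be done without cross-lane data.
      bool reject = outInnerBitsPerLane < inElemTyBitWidth;
      std::printf("reject = %d\n", reject); // prints: reject = 1
      return 0;
    }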

shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
return;
}
// For 2D -> 2D shape cast, propagate the result layout to the source.
Contributor

Consider these restrictions for now:

  1. Same-rank shape casts are not allowed.
  2. Always expand dims, never squeeze them.
  3. The new dims must be 1, and the original dims must not change.

Contributor Author

Fixed. I also added this condition for now (a rough sketch of the combined checks follows below):

  1. The result layout cannot be a slice layout, and it must have the same rank as the result.
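
A minimal standalone sketch of what the combined restrictions (the three above plus this one) could look like as a check; the struct and function names here are illustrative only and are not the actual xegpu API:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Illustrative stand-in for the result layout: the real pass queries xegpu
    // layout attributes instead.
    struct LayoutInfo {
      bool isSlice;
      int64_t rank;
    };

    // Returns true only for rank-expanding shape casts that keep the original
    // dims in order and insert unit dims, with a non-slice result layout whose
    // rank matches the result.
    bool isSupportedShapeCast(const std::vector<int64_t> &srcShape,
                              const std::vector<int64_t> &resShape,
                              const LayoutInfo &resLayout) {
      // Restrictions 1 and 2: the cast must strictly expand the rank.
      if (resShape.size() <= srcShape.size())
        return false;
      // Added restriction: non-slice result layout with the same rank as the result.
      if (resLayout.isSlice ||
          resLayout.rank != static_cast<int64_t>(resShape.size()))
        return false;
      // Restriction 3: result dims are the source dims in order, with extra
      // unit dims interleaved.
      std::size_t srcIdx = 0;
      for (int64_t d : resShape) {
        if (srcIdx < srcShape.size() && d == srcShape[srcIdx])
          ++srcIdx;       // original dim preserved
        else if (d != 1)
          return false;   // a new dim that is not 1, or a changed original dim
      }
      return srcIdx == srcShape.size();
    }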

Contributor

@adam-smnk adam-smnk left a comment

Usually smaller PRs make reviews go faster but I'll bite 😉

Overall logic looks good, only minor comments.

for (int64_t idx : permutation) {
  newLayout.layout.push_back(laneLayout.layout[idx]);
  newData.layout.push_back(laneData.layout[idx]);
  laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
Contributor

How about adding one more utility to the layout attribute, like getTransposedLayout(), so that it can be reused for sg_layout or lane_layout?
Potentially, isTransposeOf can then be simplified to transposing the input and comparing whether the results are the same.

Contributor Author

Agreed. I will add this in a separate PR and clean it up.
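
For context, a minimal sketch of what such a reusable helper might look like; here the layout is modeled as a plain vector of int32_t rather than the actual XeGPU layout attribute, so the names and signatures are assumptions:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Illustrative helper: permute any layout-like array (sg_layout, lane_layout,
    // or lane_data) by the given permutation. The real utility would live on the
    // XeGPU layout attribute and return a new attribute.
    std::vector<int32_t>
    getTransposedLayout(const std::vector<int32_t> &layout,
                        const std::vector<int64_t> &permutation) {
      assert(layout.size() == permutation.size() && "rank mismatch");
      std::vector<int32_t> transposed;
      transposed.reserve(layout.size());
      for (int64_t idx : permutation)
        transposed.push_back(layout[idx]);
      return transposed;
    }

    // With such a helper, isTransposeOf could reduce to transposing one layout
    // and comparing for equality.
    bool isTransposeOf(const std::vector<int32_t> &a,
                       const std::vector<int32_t> &b,
                       const std::vector<int64_t> &permutation) {
      return getTransposedLayout(a, permutation) == b;
    }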

func.func @vector_shape_cast_2d_to_1d_dim0_distributed(%arg0: !xegpu.tensor_desc<16x1xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
  %c0 = arith.constant 0 : index
  %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x1xf16> -> vector<16x1xf16>
  %2 = vector.shape_cast %3 : vector<16x1xf16> to vector<16xf16>
Contributor

This seems to contradict the documentation:
2) Shape cast must always expand the rank (e.g. 1D -> 2D).

and the code
https://github.com/llvm/llvm-project/pull/155517/files#diff-fcc9cdbf8bb4e5d37e661524b877082aee9b7badb0317f980c1881da564a926dR536

Not sure why this test is passing. Maybe I missed something?

Contributor Author

Sorry, I forgot to remove this test (CI was failing because of it). I have removed these tests now.

Contributor

@adam-smnk adam-smnk Sep 16, 2025

Shape cast must always expand the rank (e.g. 1D -> 2D).

If you refer to vector.shape_cast, a cast must preserve the same number of elements. Shape's rank can be freely changed up or down.

The two cases looked valid; it'd be good to understand why they failed.
If they can't be distributed, I'd leave them in as negative examples.

Contributor Author

@adam-smnk The restriction is there because we do not expect (for now) any narrowing shape casts. Shape cast is currently used to make the vector 2D after a 2D -> 1D reduction.

Adding back the tests as negative examples for now.

Contributor Author

My bad. The pass is designed to fail if we cannot assign a proper layout to an op, so I cannot add the negative example in the same file AFAIK.

Contributor

Hmm, then it's something to rethink if it impacts testing.
A separate test file would be fine, as this one's already pretty large. Not sure if verify-diagnostics can also test pass failures. TBD

Contributor Author

Not sure if verify-diagnostics can also test pass failures.

I think it can. The challenge is doing it in the same file; I did not find any examples, but I will give it a try.

return;
}
// Decide lane data based on whether the bitcast is narrowing or widening.
int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
Contributor

For narrowing bitcast, innerMostLaneData = outData[rank - 1] * bitCastRatio, instead of / bitCastRatio?
Put a TODO here?: check the layout conflict case here if ( innerMostLaneData * inElemTyBitWidth != outInnerBitsPerLane ).

Contributor Author

For narrowing bitcast, innerMostLaneData = outData[rank - 1] * bitCastRatio, instead of / bitCastRatio?

This is because in the narrowing case the source has a higher bitwidth than the result (e.g. f32 -> f16).

Put a TODO here?: check the layout conflict case here if ( innerMostLaneData * inElemTyBitWidth != outInnerBitsPerLane ).

This is not required. At this point in layout propagation the result layout is already a valid layout, and we choose innerMostLaneData such that innerMostLaneData * inElemTyBitWidth == outInnerBitsPerLane.
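
To make the arithmetic concrete, a small standalone sketch of the narrowing case; the f32 -> f16 element types and the lane_data value are assumptions for illustration:

    #include <cstdio>

    int main() {
      // Hypothetical narrowing bitcast: source f32 -> result f16.
      int inElemTyBitWidth = 32;   // f32
      int outElemTyBitWidth = 16;  // f16
      int bitCastRatio = inElemTyBitWidth / outElemTyBitWidth;       // 2
      int outInnerLaneData = 2;    // assumed result lane_data[rank - 1]: two f16 per lane

      int outInnerBitsPerLane = outInnerLaneData * outElemTyBitWidth; // 32 bits
      bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;        // true
      int innerMostLaneData = isNarrowing
                                  ? outInnerLaneData / bitCastRatio   // 2 / 2 = 1 f32
                                  : outInnerLaneData * bitCastRatio;
      // The invariant mentioned above: the source lane data covers exactly the
      // same bits as the result lane data, so no cross-lane exchange is needed.
      std::printf("%d\n", innerMostLaneData * inElemTyBitWidth == outInnerBitsPerLane); // prints 1
      return 0;
    }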

@charithaintc
Contributor Author

@adam-smnk Can you take another look and/or approve? :-)

Contributor

@Jianhui-Li Jianhui-Li left a comment

LGTM

@charithaintc charithaintc merged commit 2998c74 into llvm:main Sep 19, 2025
5 checks passed